package com.axinfu.util.encrypt;

import com.axinfu.util.EmptyUtil;

import javax.crypto.Cipher;
import java.io.ByteArrayOutputStream;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Signature;
import java.security.interfaces.RSAKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.KeySpec;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.RSAPrivateKeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * RSA utility: key generation, key parsing, encryption/decryption and signing/verification.
 *
 * <p>Encryption and decryption always use {@value #CIPHER_TRANSFORMATION_RSA} and process the
 * input in chunks sized from the key's modulus, so payloads longer than one RSA block are
 * supported. The ciphertext is the concatenation of the per-chunk cipher outputs.
 *
 * <p>Checked JCA exceptions are rethrown wrapped in {@link RuntimeException}; unchecked
 * exceptions (for example from a {@code null} argument or an unknown charset name) propagate
 * unchanged.
 *
 * @author zjn
 * @since 2022/3/23
 */
@SuppressWarnings("unused")
public final class EncryptionRsa {

    /** Default character encoding used by the {@code String} overloads. */
    public static final String ENCODING_UTF8 = "UTF-8";

    /** Algorithm name passed to {@link KeyFactory} and {@link KeyPairGenerator}. */
    public static final String ALGORITHM_RSA = "RSA";

    /**
     * Cipher transformation used for encryption and decryption.
     *
     * <p>Pinned explicitly because the padding implied by the bare {@code "RSA"} transformation
     * differs between providers (Android, for example, defaults to no padding), while the chunk
     * size used by {@link #encryptRsa(Key, byte[])} assumes PKCS#1 v1.5 padding.
     */
    public static final String CIPHER_TRANSFORMATION_RSA = "RSA/ECB/PKCS1Padding";

    /** MD5-with-RSA signature algorithm name (weak digest; kept for compatibility). */
    public static final String SIGN_ALGORITHMS_MD5 = "MD5WithRSA";

    /** SHA1-with-RSA signature algorithm name (weak digest; kept for compatibility). */
    public static final String SIGN_ALGORITHMS_SHA1 = "SHA1WithRSA";

    /** SHA256-with-RSA signature algorithm name; preferred for new signatures. */
    public static final String SIGN_ALGORITHMS_SHA256 = "SHA256withRSA";

    /** Bytes of overhead PKCS#1 v1.5 encryption padding adds to every block. */
    private static final int PKCS1_PADDING_OVERHEAD = 11;

    private EncryptionRsa() {
    }

    /**
     * Key material produced by {@link #generatorRsaKey(int, byte[])}: the raw RSA components
     * (also as upper-case hexadecimal) and the encoded public/private keys as bytes, hex and
     * Base64.
     */
    public static class EncryptionKeyRsa {
        private final BigInteger modulus;
        private final String modulusHex;
        private final BigInteger publicExponent;
        private final String publicExponentHex;
        private final BigInteger privateExponent;
        private final String privateExponentHex;
        private byte[] publicKey;
        private String publicKeyHex;
        private String publicKeyBase64;
        private byte[] privateKey;
        private String privateKeyHex;
        private String privateKeyBase64;

        /**
         * Creates the holder and caches upper-case hex forms of each component.
         *
         * @param modulus         RSA modulus
         * @param publicExponent  public exponent
         * @param privateExponent private exponent
         */
        public EncryptionKeyRsa(BigInteger modulus, BigInteger publicExponent, BigInteger privateExponent) {
            this.modulus = modulus;
            this.modulusHex = modulus.toString(16).toUpperCase();
            this.publicExponent = publicExponent;
            this.publicExponentHex = publicExponent.toString(16).toUpperCase();
            this.privateExponent = privateExponent;
            this.privateExponentHex = privateExponent.toString(16).toUpperCase();
        }

        public BigInteger getModulus() {
            return modulus;
        }

        public String getModulusHex() {
            return modulusHex;
        }

        public BigInteger getPublicExponent() {
            return publicExponent;
        }

        public String getPublicExponentHex() {
            return publicExponentHex;
        }

        public BigInteger getPrivateExponent() {
            return privateExponent;
        }

        public String getPrivateExponentHex() {
            return privateExponentHex;
        }

        /**
         * @return a copy of the encoded public key, or {@code null} if none was set
         */
        public byte[] getPublicKey() {
            return copy(publicKey);
        }

        /**
         * Stores the encoded public key and caches its hex and Base64 forms.
         *
         * @param publicKey encoded public key; a copy is kept so later changes to the caller's
         *                  array cannot desynchronise the cached hex/Base64 strings
         */
        public void setPublicKey(byte[] publicKey) {
            this.publicKey = copy(publicKey);
            this.publicKeyHex = Encryption.bytesToHex(publicKey);
            this.publicKeyBase64 = Encryption.encryptBase64ToString(publicKey);
        }

        public String getPublicKeyHex() {
            return publicKeyHex;
        }

        public String getPublicKeyBase64() {
            return publicKeyBase64;
        }

        /**
         * @return a copy of the encoded private key, or {@code null} if none was set
         */
        public byte[] getPrivateKey() {
            return copy(privateKey);
        }

        /**
         * Stores the encoded private key and caches its hex and Base64 forms.
         *
         * @param privateKey encoded private key; a copy is kept so later changes to the caller's
         *                   array cannot desynchronise the cached hex/Base64 strings
         */
        public void setPrivateKey(byte[] privateKey) {
            this.privateKey = copy(privateKey);
            this.privateKeyHex = Encryption.bytesToHex(privateKey);
            this.privateKeyBase64 = Encryption.encryptBase64ToString(privateKey);
        }

        public String getPrivateKeyHex() {
            return privateKeyHex;
        }

        public String getPrivateKeyBase64() {
            return privateKeyBase64;
        }

        /** Null-safe array copy. */
        private static byte[] copy(byte[] bytes) {
            return bytes == null ? null : bytes.clone();
        }
    }

    /**
     * Generates a new RSA key pair.
     *
     * @param keySize modulus size in bits, passed to {@link KeyPairGenerator#initialize(int, SecureRandom)}
     * @param seed    optional seed for the {@link SecureRandom}; when {@code EmptyUtil.isNotEmpty(seed)}
     *                is false (presumably {@code null} or zero length) a self-seeded instance is used.
     *                NOTE(review): whether a seed makes generation reproducible or only adds entropy
     *                depends on the platform's default SecureRandom algorithm — do not rely on either.
     * @return key material; the encoded keys come from {@link Key#getEncoded()}, which with the
     *         standard JDK provider yields X.509 (public) and PKCS#8 (private) encodings
     */
    public static EncryptionKeyRsa generatorRsaKey(int keySize, byte[] seed) {
        try {
            KeyPairGenerator keyPairGen = KeyPairGenerator.getInstance(ALGORITHM_RSA);
            SecureRandom secureRandom = EmptyUtil.isNotEmpty(seed) ? new SecureRandom(seed) : new SecureRandom();
            keyPairGen.initialize(keySize, secureRandom);
            KeyPair keyPair = keyPairGen.generateKeyPair();
            RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
            RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();

            EncryptionKeyRsa encryptionKeyRsa = new EncryptionKeyRsa(publicKey.getModulus(),
                    publicKey.getPublicExponent(), privateKey.getPrivateExponent());
            encryptionKeyRsa.setPublicKey(publicKey.getEncoded());
            encryptionKeyRsa.setPrivateKey(privateKey.getEncoded());
            return encryptionKeyRsa;
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Generates a new RSA key pair using a self-seeded {@link SecureRandom}.
     *
     * @param keySize modulus size in bits
     * @return key material
     */
    public static EncryptionKeyRsa generatorRsaKey(int keySize) {
        return generatorRsaKey(keySize, null);
    }

    /**
     * Builds an RSA public key from its modulus and public exponent.
     *
     * @param modulus  modulus
     * @param exponent public exponent
     * @return public key
     */
    public static RSAPublicKey parseRsaPublicKey(BigInteger modulus, BigInteger exponent) {
        return generatePublicKey(new RSAPublicKeySpec(modulus, exponent));
    }

    /**
     * Parses an X.509-encoded public key given as a hexadecimal string.
     *
     * @param publicKeyHex hexadecimal X.509 encoding of the public key
     * @return public key
     */
    public static RSAPublicKey parseRsaPublicKeyFromHex(String publicKeyHex) {
        return generatePublicKey(new X509EncodedKeySpec(Encryption.hexToBytes(publicKeyHex)));
    }

    /**
     * Parses an X.509-encoded public key given as a Base64 string.
     *
     * @param publicKeyBase64 Base64 X.509 encoding of the public key
     * @return public key
     */
    public static RSAPublicKey parseRsaPublicKeyFromBase64(String publicKeyBase64) {
        return generatePublicKey(new X509EncodedKeySpec(Encryption.decryptBase64(publicKeyBase64)));
    }

    /**
     * Builds an RSA private key from its modulus and private exponent.
     *
     * @param modulus  modulus
     * @param exponent private exponent
     * @return private key
     */
    public static RSAPrivateKey parseRsaPrivateKey(BigInteger modulus, BigInteger exponent) {
        return generatePrivateKey(new RSAPrivateKeySpec(modulus, exponent));
    }

    /**
     * Parses a PKCS#8-encoded private key given as a hexadecimal string.
     *
     * @param privateKeyHex hexadecimal PKCS#8 encoding of the private key
     * @return private key
     */
    public static RSAPrivateKey parseRsaPrivateKeyFromHex(String privateKeyHex) {
        return generatePrivateKey(new PKCS8EncodedKeySpec(Encryption.hexToBytes(privateKeyHex)));
    }

    /**
     * Parses a PKCS#8-encoded private key given as a Base64 string.
     *
     * @param privateKeyBase64 Base64 PKCS#8 encoding of the private key
     * @return private key
     */
    public static RSAPrivateKey parseRsaPrivateKeyFromBase64(String privateKeyBase64) {
        return generatePrivateKey(new PKCS8EncodedKeySpec(Encryption.decryptBase64(privateKeyBase64)));
    }

    /**
     * Encrypts {@code data} with {@value #CIPHER_TRANSFORMATION_RSA}, splitting it into chunks of
     * at most (modulus bytes - 11) so inputs longer than one RSA block are supported.
     * An empty input yields an empty output.
     *
     * @param key  RSA public or private key (must implement {@link RSAKey})
     * @param data plaintext
     * @return concatenated ciphertext blocks
     */
    public static byte[] encryptRsa(Key key, byte[] data) {
        return doFinalInChunks(Cipher.ENCRYPT_MODE, key, data);
    }

    /**
     * Encrypts {@code data} and returns the ciphertext as hexadecimal.
     *
     * @param key  RSA key
     * @param data plaintext
     * @return hexadecimal ciphertext
     */
    public static String encryptRsaToHex(Key key, byte[] data) {
        return Encryption.bytesToHex(encryptRsa(key, data));
    }

    /**
     * Encodes {@code data} with {@code encoding}, encrypts it and returns hexadecimal.
     *
     * @param key      RSA key
     * @param data     plaintext
     * @param encoding charset name used to turn {@code data} into bytes
     * @return hexadecimal ciphertext
     */
    public static String encryptRsaToHex(Key key, String data, String encoding) {
        return encryptRsaToHex(key, toBytes(data, encoding));
    }

    /**
     * Encodes {@code data} as UTF-8, encrypts it and returns hexadecimal.
     *
     * @param key  RSA key
     * @param data plaintext
     * @return hexadecimal ciphertext
     */
    public static String encryptRsaToHex(Key key, String data) {
        return encryptRsaToHex(key, data, ENCODING_UTF8);
    }

    /**
     * Encrypts {@code data} and returns the ciphertext as Base64.
     *
     * @param key  RSA key
     * @param data plaintext
     * @return Base64 ciphertext
     */
    public static String encryptRsaToBase64(Key key, byte[] data) {
        return Encryption.encryptBase64ToString(encryptRsa(key, data));
    }

    /**
     * Encodes {@code data} with {@code encoding}, encrypts it and returns Base64.
     *
     * @param key      RSA key
     * @param data     plaintext
     * @param encoding charset name used to turn {@code data} into bytes
     * @return Base64 ciphertext
     */
    public static String encryptRsaToBase64(Key key, String data, String encoding) {
        return encryptRsaToBase64(key, toBytes(data, encoding));
    }

    /**
     * Encodes {@code data} as UTF-8, encrypts it and returns Base64.
     *
     * @param key  RSA key
     * @param data plaintext
     * @return Base64 ciphertext
     */
    public static String encryptRsaToBase64(Key key, String data) {
        return encryptRsaToBase64(key, data, ENCODING_UTF8);
    }

    /**
     * Decrypts data produced by {@link #encryptRsa(Key, byte[])}: the input is processed in
     * modulus-length blocks and the decrypted chunks are concatenated.
     *
     * @param key  RSA key matching the one used for encryption (must implement {@link RSAKey})
     * @param data concatenated ciphertext blocks
     * @return plaintext
     */
    public static byte[] decryptRsa(Key key, byte[] data) {
        return doFinalInChunks(Cipher.DECRYPT_MODE, key, data);
    }

    /**
     * Decrypts hexadecimal ciphertext.
     *
     * @param key  RSA key
     * @param data hexadecimal ciphertext
     * @return plaintext bytes
     */
    public static byte[] decryptRsaFromHex(Key key, String data) {
        return decryptRsa(key, Encryption.hexToBytes(data));
    }

    /**
     * Decrypts hexadecimal ciphertext and decodes the plaintext with {@code encoding}.
     *
     * @param key      RSA key
     * @param data     hexadecimal ciphertext
     * @param encoding charset name of the plaintext
     * @return plaintext string
     */
    public static String decryptRsaFromHexToString(Key key, String data, String encoding) {
        return toText(decryptRsaFromHex(key, data), encoding);
    }

    /**
     * Decrypts hexadecimal ciphertext and decodes the plaintext as UTF-8.
     *
     * @param key  RSA key
     * @param data hexadecimal ciphertext
     * @return plaintext string
     */
    public static String decryptRsaFromHexToString(Key key, String data) {
        return decryptRsaFromHexToString(key, data, ENCODING_UTF8);
    }

    /**
     * Decrypts Base64 ciphertext.
     *
     * @param key  RSA key
     * @param data Base64 ciphertext
     * @return plaintext bytes
     */
    public static byte[] decryptRsaFromBase64(Key key, String data) {
        return decryptRsa(key, Encryption.decryptBase64(data));
    }

    /**
     * Decrypts Base64 ciphertext and decodes the plaintext with {@code encoding}.
     *
     * @param key      RSA key
     * @param data     Base64 ciphertext
     * @param encoding charset name of the plaintext
     * @return plaintext string
     */
    public static String decryptRsaFromBase64ToString(Key key, String data, String encoding) {
        return toText(decryptRsaFromBase64(key, data), encoding);
    }

    /**
     * Decrypts Base64 ciphertext and decodes the plaintext as UTF-8.
     *
     * @param key  RSA key
     * @param data Base64 ciphertext
     * @return plaintext string
     */
    public static String decryptRsaFromBase64ToString(Key key, String data) {
        return decryptRsaFromBase64ToString(key, data, ENCODING_UTF8);
    }

    /**
     * Signs {@code data}.
     *
     * @param signAlgorithms JCA signature algorithm name, e.g. {@link #SIGN_ALGORITHMS_SHA256}
     * @param privateKey     signing key
     * @param data           data to sign
     * @return raw signature bytes
     */
    public static byte[] signRsa(String signAlgorithms, PrivateKey privateKey, byte[] data) {
        try {
            Signature signature = Signature.getInstance(signAlgorithms);
            signature.initSign(privateKey);
            signature.update(data);
            return signature.sign();
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Signs {@code data} and returns the signature as hexadecimal.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @return hexadecimal signature
     */
    public static String signRsaToHex(String signAlgorithms, PrivateKey privateKey, byte[] data) {
        return Encryption.bytesToHex(signRsa(signAlgorithms, privateKey, data));
    }

    /**
     * Encodes {@code data} with {@code encoding}, signs it and returns hexadecimal.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @param encoding       charset name used to turn {@code data} into bytes
     * @return hexadecimal signature
     */
    public static String signRsaToHex(String signAlgorithms, PrivateKey privateKey, String data, String encoding) {
        return signRsaToHex(signAlgorithms, privateKey, toBytes(data, encoding));
    }

    /**
     * Encodes {@code data} as UTF-8, signs it and returns hexadecimal.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @return hexadecimal signature
     */
    public static String signRsaToHex(String signAlgorithms, PrivateKey privateKey, String data) {
        return signRsaToHex(signAlgorithms, privateKey, data, ENCODING_UTF8);
    }

    /**
     * Signs {@code data} and returns the signature as Base64.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @return Base64 signature
     */
    public static String signRsaToBase64(String signAlgorithms, PrivateKey privateKey, byte[] data) {
        return Encryption.encryptBase64ToString(signRsa(signAlgorithms, privateKey, data));
    }

    /**
     * Encodes {@code data} with {@code encoding}, signs it and returns Base64.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @param encoding       charset name used to turn {@code data} into bytes
     * @return Base64 signature
     */
    public static String signRsaToBase64(String signAlgorithms, PrivateKey privateKey, String data, String encoding) {
        return signRsaToBase64(signAlgorithms, privateKey, toBytes(data, encoding));
    }

    /**
     * Encodes {@code data} as UTF-8, signs it and returns Base64.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param privateKey     signing key
     * @param data           data to sign
     * @return Base64 signature
     */
    public static String signRsaToBase64(String signAlgorithms, PrivateKey privateKey, String data) {
        return signRsaToBase64(signAlgorithms, privateKey, data, ENCODING_UTF8);
    }

    /**
     * Verifies a signature over {@code data}.
     *
     * <p>Returns {@code false} when the provider reports a mismatch. A {@code SignatureException}
     * raised by the provider (depending on the provider, e.g. for a signature of the wrong
     * length) is rethrown wrapped in {@link RuntimeException}.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param sign           raw signature bytes
     * @return whether the signature is valid
     */
    public static boolean verifyRsa(String signAlgorithms, PublicKey publicKey, byte[] data, byte[] sign) {
        try {
            Signature signature = Signature.getInstance(signAlgorithms);
            signature.initVerify(publicKey);
            signature.update(data);
            return signature.verify(sign);
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Verifies a hexadecimal signature over {@code data}.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param signHex        hexadecimal signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromHex(String signAlgorithms, PublicKey publicKey, byte[] data, String signHex) {
        return verifyRsa(signAlgorithms, publicKey, data, Encryption.hexToBytes(signHex));
    }

    /**
     * Encodes {@code data} with {@code encoding} and verifies a hexadecimal signature over it.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param encoding       charset name used to turn {@code data} into bytes
     * @param signHex        hexadecimal signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromHex(String signAlgorithms, PublicKey publicKey, String data, String encoding,
                                           String signHex) {
        return verifyRsaFromHex(signAlgorithms, publicKey, toBytes(data, encoding), signHex);
    }

    /**
     * Encodes {@code data} as UTF-8 and verifies a hexadecimal signature over it.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param signHex        hexadecimal signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromHex(String signAlgorithms, PublicKey publicKey, String data, String signHex) {
        return verifyRsaFromHex(signAlgorithms, publicKey, data, ENCODING_UTF8, signHex);
    }

    /**
     * Verifies a Base64 signature over {@code data}.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param signBase64     Base64 signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromBase64(String signAlgorithms, PublicKey publicKey, byte[] data,
                                              String signBase64) {
        return verifyRsa(signAlgorithms, publicKey, data, Encryption.decryptBase64(signBase64));
    }

    /**
     * Encodes {@code data} with {@code encoding} and verifies a Base64 signature over it.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param encoding       charset name used to turn {@code data} into bytes
     * @param signBase64     Base64 signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromBase64(String signAlgorithms, PublicKey publicKey, String data,
                                              String encoding, String signBase64) {
        return verifyRsaFromBase64(signAlgorithms, publicKey, toBytes(data, encoding), signBase64);
    }

    /**
     * Encodes {@code data} as UTF-8 and verifies a Base64 signature over it.
     *
     * @param signAlgorithms JCA signature algorithm name
     * @param publicKey      verification key
     * @param data           signed data
     * @param signBase64     Base64 signature
     * @return whether the signature is valid
     */
    public static boolean verifyRsaFromBase64(String signAlgorithms, PublicKey publicKey, String data,
                                              String signBase64) {
        return verifyRsaFromBase64(signAlgorithms, publicKey, data, ENCODING_UTF8, signBase64);
    }

    /**
     * Runs the RSA cipher over {@code data} chunk by chunk and concatenates the outputs.
     * Encryption chunks carry at most (modulus bytes - PKCS#1 overhead) bytes; decryption chunks
     * are one modulus length each.
     */
    private static byte[] doFinalInChunks(int mode, Key key, byte[] data) {
        try {
            Cipher cipher = Cipher.getInstance(CIPHER_TRANSFORMATION_RSA);
            cipher.init(mode, key);

            // Round up: a modulus whose bit length is not a multiple of 8 still occupies a whole
            // final byte, and each ciphertext block is that many bytes long.
            int modulusBytes = (((RSAKey) key).getModulus().bitLength() + 7) / 8;
            int chunkSize = mode == Cipher.ENCRYPT_MODE
                    ? modulusBytes - PKCS1_PADDING_OVERHEAD
                    : modulusBytes;
            if (chunkSize <= 0) {
                // Guards against an endless loop below for absurdly small moduli.
                throw new IllegalArgumentException("RSA modulus too small: " + modulusBytes + " bytes");
            }

            ByteArrayOutputStream out = new ByteArrayOutputStream();
            for (int offset = 0; offset < data.length; offset += chunkSize) {
                int length = Math.min(chunkSize, data.length - offset);
                // doFinal resets the cipher to its initialized state, so it can be reused per chunk.
                byte[] block = cipher.doFinal(data, offset, length);
                out.write(block, 0, block.length);
            }
            return out.toByteArray();
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /** Generates an {@link RSAPublicKey} from a key spec, wrapping checked JCA exceptions. */
    private static RSAPublicKey generatePublicKey(KeySpec keySpec) {
        try {
            return (RSAPublicKey) KeyFactory.getInstance(ALGORITHM_RSA).generatePublic(keySpec);
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /** Generates an {@link RSAPrivateKey} from a key spec, wrapping checked JCA exceptions. */
    private static RSAPrivateKey generatePrivateKey(KeySpec keySpec) {
        try {
            return (RSAPrivateKey) KeyFactory.getInstance(ALGORITHM_RSA).generatePrivate(keySpec);
        } catch (GeneralSecurityException e) {
            throw new RuntimeException(e);
        }
    }

    /** Encodes {@code data} with the named charset; an unknown name raises an unchecked exception. */
    private static byte[] toBytes(String data, String encoding) {
        return data.getBytes(Charset.forName(encoding));
    }

    /** Decodes {@code data} with the named charset; an unknown name raises an unchecked exception. */
    private static String toText(byte[] data, String encoding) {
        return new String(data, Charset.forName(encoding));
    }
}
